
using Flux,
    DiffEqFlux,
    DifferentialEquations,
    XLSX,
    DataFrames,
    DiffEqSensitivity,
    Plots, Optim, OrdinaryDiffEq

using Random, JLD
# Persist the per-region data and predictions to a JLD workspace file.
#
# Every variable written here is assigned further down in this file, so on a fresh
# top-to-bottom run none of them exists yet and an unconditional `save` would stop the
# script with an UndefVarError. Save only when all of them are already defined (for
# example when this line is re-run in a live session); otherwise warn and carry on.
let required = (:x_it, :y_it, :x_sk, :y_sk, :x_wh, :y_wh, :pred_it, :pred_sk, :pred_wh)
    missing_vars = filter(name -> !isdefined(@__MODULE__, name), required)
    if isempty(missing_vars)
        save("/Users/urielyang/OneDrive - Emory University/Honors/workspace.jld",
            "x_it", x_it, "y_it", y_it, "x_sk", x_sk, "y_sk", y_sk, "x_wh", x_wh, "y_wh", y_wh,
            "pred_it", pred_it, "pred_sk", pred_sk, "pred_wh", pred_wh)
    else
        @warn "Skipping workspace save; variables not defined yet" missing_vars
    end
end


# Read cell range A1:I85 of the sheet named "Sheet" from the Italy workbook and wrap
# the resulting cell matrix in a DataFrame.
df = XLSX.readdata(
    "/Users/urielyang/OneDrive - Emory University/Honors/Data_Italy.xlsx",
    "Sheet!A1:I85",
) |> DataFrame

# Skip the first row (presumably the header row -- confirm against the sheet) and keep
# columns 5 and 6 only: infected and recovered.
final_data = df[2:end, [5, 6]]

# Series passed into the "train" function, as Float64 vectors:
# x is plotted as "Infected" and y as "Recovered".
x = convert(Vector{Float64}, final_data[:, 1])
y = convert(Vector{Float64}, final_data[:, 2])

# NOTE(review): all per-region names alias the same two vectors loaded above, so they
# hold whichever spreadsheet was read last -- confirm this is intended.
x_sk = x_wh = x_it = x
y_sk = y_wh = y_it = y

"""
    SEIR!(du, u, p, t)

In-place right-hand side of a four-compartment SEIR model.

`u` holds the state in the order susceptible, exposed, infectious, removed, and the
matching derivatives are written into `du[1:4]`. The rates β = 0.06, σ = 0.44 and
γ = 0.04 are fixed inside the function; `p` and `t` are accepted but never read.

Returns the removal flow `γ * i` (the value assigned to `du[4]`).
"""
function SEIR!(du, u, p, t)
    β, σ, γ = 0.06, 0.44, 0.04

    susceptible, exposed, infectious, removed = u
    population = susceptible + exposed + infectious + removed

    # Flows between neighbouring compartments; each one leaves one compartment and
    # enters the next, so the derivatives sum to zero.
    new_exposures = β * susceptible * infectious / population
    onsets = σ * exposed
    removals = γ * infectious

    du[1] = -new_exposures
    du[2] = new_exposures - onsets
    du[3] = onsets - removals
    du[4] = removals
    return removals
end

# Population by region: China 11m, SK 52m, Italy 60m. The Italy figure is used as the
# initial susceptible count; the state order is [S, E, I, R].
u0 = Float64[60e6, 10000, 500, 10]
# Placeholder parameter vector: the ODE right-hand side hard-codes its rates.
p = [0]

# Integrate from day 0 through day 79.
ts, tend = 0.0, 79.0
tspan = (ts, tend)
prob = ODEProblem(SEIR!, u0, tspan, p)

"""
    predict(p)

Solve the global `prob` with `Tsit5()`, starting from the global `u0` with parameters
`p`, saving at unit steps from `ts` to `tend`. Returns the solution converted with
`Array`.
"""
function predict(p)
    solution = solve(prob, Tsit5(); u0 = u0, p = p, saveat = ts:1:tend)
    return Array(solution)
end

# Solve once; the per-region names all refer to this same prediction array.
prediction = predict(p)
pred_sk, pred_wh, pred_it = prediction, prediction, prediction

# Plot the observed series against the model prediction on the simulation's time grid.
t_step = ts:1:tend

# Observations: slice to the length of the time grid instead of a hard-coded 80, so the
# scatter series stay the same length as `t_step` if `ts` or `tend` is changed.
scatter(t_step, y[1:length(t_step)], color = "blue", label = "Recovered", markerstrokecolor = "blue")
scatter!(t_step, x[1:length(t_step)], color = "red", label = "Infected", markerstrokecolor = "red")

# Model output: row 4 of the prediction is the recovered compartment, row 3 the infected.
plot!(t_step, prediction[4, :], color = "blue", label = "Predicted recovered")
plot!(t_step, prediction[3, :], color = "red", label = "Predicted infected")

# Legend styling and axis labels.
plot!(legendfont = font("Times new roman", 7), legend=:topleft, foreground_color_legend = nothing, background_color_legend = nothing)
plot!(xlabel = "Days post 500 infected", ylabel = "Number of cases")

# Dashed marker at day 40, labelled as the train/test split.
vline!([40], color = "black", label = "Split for training and testing data", linestyle = :dash)
